# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' @include expression.R
#' @include record-batch.R
#' @include table.R

arrow_dplyr_query <- function(.data) {
  # Wrap an Arrow data object (Table, RecordBatch, or Dataset) together with
  # the state of the user's dplyr query: selected columns, filters, and
  # group vars.
  #
  # Most dplyr methods are shared, i.e.
  # method.Table == method.RecordBatch == method.Dataset == method.arrow_dplyr_query
  # which works because every method passes .data through this function,
  # and an object that is already a query is returned untouched.
  already_query <- inherits(.data, "arrow_dplyr_query")
  if (already_query) {
    return(.data)
  }
  query <- list(
    # A clone of the source object, so the query holds its own reference
    .data = .data$clone(),
    # Named character vector:
    # * values are the column names as they exist in the Arrow data
    # * names are the names the columns should end up with (records renames)
    selected_columns = set_names(names(.data)),
    # TRUE means "no filter yet"; filter() replaces it with an Expression
    filtered_rows = TRUE,
    # Character vector of grouping columns (using the renamed names); these
    # are kept when data is pulled into R
    group_by_vars = character()
  )
  class(query) <- "arrow_dplyr_query"
  query
}

#' @export
print.arrow_dplyr_query <- function(x, ...) {
  # Print a summary of the query: source class, the selected (possibly
  # renamed) columns with their types, any filter, and any grouping.
  source_schema <- x$.data$schema
  selected <- x$selected_columns
  # Field ToString() includes the field name as it is in the data; strip
  # everything up to the first ": " and pair the rest with the renamed name
  field_types <- sub(
    "^.*?: ", "",
    map_chr(selected, function(col) source_schema$GetFieldByName(col)$ToString())
  )
  field_block <- paste(names(selected), field_types, sep = ": ", collapse = "\n")

  cat(paste0(class(x$.data)[1], " (query)\n"))
  cat(paste0(field_block, "\n"))
  cat("\n")
  has_filter <- !isTRUE(x$filtered_rows)
  if (has_filter) {
    cat(paste0("* Filter: ", x$filtered_rows$ToString(), "\n"))
  }
  if (length(x$group_by_vars) > 0) {
    cat(paste0("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n"))
  }
  cat("See $.data for the source Arrow object\n")
  invisible(x)
}

# These are the names reflecting all select/rename, not what is in Arrow
#' @export
names.arrow_dplyr_query <- function(x) {
  # The names (not values) of selected_columns are the user-facing names
  selected <- x$selected_columns
  names(selected)
}

#' @export
dim.arrow_dplyr_query <- function(x) {
  # Dimensions of the query result: c(rows, columns).
  # The row count is only known without computing when no filter has been
  # applied (filtered_rows is still TRUE). Otherwise NA is returned with a
  # warning. The column count reflects the current select()/rename() state.
  # Use the full element name: `x$filtered` only worked through `$` partial
  # matching, which silently breaks if another element starts with "filtered".
  if (isTRUE(x$filtered_rows)) {
    rows <- x$.data$num_rows
  } else {
    warning(
      "For arrow dplyr queries that call filter(), ",
      "dim() returns NA for the number of rows.\n",
      "Call collect() to pull data into R to access the number of rows.",
      call. = FALSE
    )
    rows <- NA_integer_
  }
  cols <- length(names(x))
  c(rows, cols)
}

# The following S3 methods are registered on load if dplyr is present
select.arrow_dplyr_query <- function(.data, ...) {
  # Wrap the input (a no-op if it is already a query), then delegate to the
  # shared select/rename helper using its default selector
  query <- arrow_dplyr_query(.data)
  column_select(query, !!!enquos(...))
}
select.Dataset <- select.Table <- select.RecordBatch <- select.arrow_dplyr_query

#' @importFrom tidyselect vars_rename
rename.arrow_dplyr_query <- function(.data, ...) {
  # Same as select(), but with tidyselect::vars_rename, which keeps every
  # column and only changes names
  query <- arrow_dplyr_query(.data)
  column_select(query, !!!enquos(...), .FUN = vars_rename)
}
rename.Dataset <- rename.Table <- rename.RecordBatch <- rename.arrow_dplyr_query

column_select <- function(.data, ..., .FUN = vars_select) {
  # Shared implementation of select() and rename() for arrow_dplyr_query.
  # .FUN is tidyselect::vars_select or tidyselect::vars_rename. It is applied
  # to names(.data), i.e. the user-facing names after any earlier renaming.
  # Its result has current names as values and new names as names.
  selection <- .FUN(names(.data), !!!enquos(...))

  # Index by current name so that each entry keeps pointing at the original
  # column in the Arrow data, even across several renaming steps
  .data$selected_columns <- set_names(
    .data$selected_columns[selection],
    names(selection)
  )

  # Propagate renames into the other query state that stores names.
  # Filters hold references to Arrow objects, not names, so they are untouched.
  renamed <- selection[names(selection) != selection]
  if (length(renamed) > 0) {
    grp <- .data$group_by_vars
    pos <- match(grp, renamed)
    hit <- !is.na(pos)
    grp[hit] <- names(renamed)[pos[hit]]
    .data$group_by_vars <- grp
  }
  .data
}

# filter() for arrow_dplyr_query (also registered for Dataset, Table, and
# RecordBatch below). Each filter is evaluated in the data mask built by
# filter_mask(), so it should produce an Arrow Expression rather than an R
# vector. The resulting Expressions are ANDed onto the query by set_filters().
# If any filter cannot be evaluated that way:
# * on a Dataset query, this stops with an error (no automatic collect());
# * on a Table/RecordBatch, it warns, applies the translatable filters in
#   Arrow, collect()s the data, and runs the remaining filters in R via dplyr.
filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) {
  # TODO something with the .preserve argument
  filts <- quos(...)
  if (length(filts) == 0) {
    # Nothing to do: return the input as given (it is not wrapped in a query)
    return(.data)
  }

  .data <- arrow_dplyr_query(.data)
  # The filter() method works by evaluating the filters to generate Expressions
  # with references to Arrays (if .data is Table/RecordBatch) or Fields (if
  # .data is a Dataset).
  dm <- filter_mask(.data)
  filters <- lapply(filts, function (f) {
    # This should yield an Expression as long as the filter function(s) are
    # implemented in Arrow.
    tryCatch(eval_tidy(f, dm), error = function(e) {
      # Look for the cases where bad input was given, i.e. this would fail
      # in regular dplyr anyway, and let those raise those as errors;
      # else, for things not supported by Arrow return a "try-error",
      # which we'll handle differently
      msg <- conditionMessage(e)
      if (grepl("object '.*'.not.found", msg)) {
        stop(e)
      }
      if (grepl('could not find function ".*"', msg)) {
        stop(e)
      }
      invisible(structure(msg, class = "try-error", condition = e))
    })
  })
  # TRUE for each filter that could not be translated to Arrow
  bad_filters <- map_lgl(filters, ~inherits(., "try-error"))
  if (any(bad_filters)) {
    # Human-readable list of the untranslatable filter expressions
    bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE)
    if (query_on_dataset(.data)) {
      # Abort. We don't want to auto-collect if this is a Dataset because that
      # could blow up, too big.
      stop(
        "Filter expression not supported for Arrow Datasets: ", bads,
        "\nCall collect() first to pull data into R.",
        call. = FALSE
      )
    } else {
      # TODO: only show this in some debug mode?
      warning(
        "Filter expression not implemented in Arrow: ", bads, "; pulling data into R",
        immediate. = TRUE,
        call. = FALSE
      )
      # Set any valid filters first, then collect and then apply the invalid ones in R
      # (filters[!bad_filters] may be empty when every filter is untranslatable)
      .data <- set_filters(.data, filters[!bad_filters])
      return(dplyr::filter(dplyr::collect(.data), !!!filts[bad_filters]))
    }
  }

  # All filters translated: AND them onto the query; nothing is computed yet
  set_filters(.data, filters)
}
filter.Dataset <- filter.Table <- filter.RecordBatch <- filter.arrow_dplyr_query

# Create a data mask for evaluating a filter expression
filter_mask <- function(.data) {
  # The mask binds two kinds of names:
  # * the supported operators/functions, each building an Arrow expression
  #   (make_expression() for Datasets, array_expression() otherwise);
  # * every selected column under its user-facing name, bound to a
  #   FieldExpression (Dataset) or to the column itself (Table/RecordBatch).
  # TODO: define functions in env once, outside of this function
  on_dataset <- query_on_dataset(.data)

  make_op <- function(operator) {
    force(operator)
    function(e1, e2) {
      if (on_dataset) {
        make_expression(operator, e1, e2)
      } else {
        array_expression(operator, e1, e2)
      }
    }
  }
  bind_column <- function(x) {
    if (on_dataset) FieldExpression$create(x) else .data$.data[[x]]
  }

  mask_env <- env()
  op_names <- set_names(c(names(comparison_function_map), "&", "|", "%in%"))
  env_bind(mask_env, !!!lapply(op_names, make_op))
  # selected_columns is named by the renamed names, so renaming is handled
  # by binding each column under its list name
  env_bind(mask_env, !!!lapply(.data$selected_columns, bind_column))
  new_data_mask(mask_env)
}

set_filters <- function(.data, expressions) {
  # AND a list of filter Expressions together and combine the result with any
  # filter already stored in .data$filtered_rows.
  #
  # .data: an arrow_dplyr_query
  # expressions: a list of Expressions; may be empty (e.g. filter() found no
  #   translatable filters), in which case .data is returned unchanged.
  # Returns the updated .data.
  if (length(expressions) == 0) {
    # Reduce() over an empty list returns NULL. Assigning NULL via `$<-`
    # would delete filtered_rows, and `expr & NULL` is not a valid filter.
    return(.data)
  }
  new_filter <- Reduce("&", expressions)
  if (isTRUE(.data$filtered_rows)) {
    # TRUE is default (i.e. no filter yet), so we don't need to & with it
    .data$filtered_rows <- new_filter
  } else {
    .data$filtered_rows <- .data$filtered_rows & new_filter
  }
  .data
}

collect.arrow_dplyr_query <- function(x, ...) {
  # Evaluate the query and pull the result into an R data.frame.
  # Only the selected columns (plus any group_by columns) and the filtered
  # rows are read. Columns get their user-facing (renamed) names, and the
  # result is returned as a grouped_df when the query is grouped.
  # (`cols` rather than `colnames`, to avoid shadowing base::colnames)
  cols <- x$selected_columns
  # Be sure to retain any group_by vars, even if they were not selected
  group_cols <- setdiff(dplyr::group_vars(x), names(cols))
  if (length(group_cols) > 0) {
    cols <- c(cols, set_names(group_cols))
  }

  # Pull only the selected rows and cols into R
  if (query_on_dataset(x)) {
    # See dataset.R for Dataset and Scanner(Builder) classes
    scanner_builder <- x$.data$NewScan()
    scanner_builder$UseThreads()
    scanner_builder$Project(cols)
    if (!isTRUE(x$filtered_rows)) {
      scanner_builder$Filter(x$filtered_rows)
    }
    df <- as.data.frame(scanner_builder$Finish()$ToTable())
  } else {
    # This is a Table/RecordBatch. See record-batch.R for the [ method
    df <- as.data.frame(x$.data[x$filtered_rows, cols])
  }
  # In case variables were renamed, apply those names
  names(df) <- names(cols)

  # Preserve groupings, if present. grouped_df() documents `vars` as a
  # character vector, so pass the group names rather than symbols.
  if (length(x$group_by_vars) > 0) {
    df <- dplyr::grouped_df(df, dplyr::group_vars(x))
  }
  df
}
collect.Table <- as.data.frame.Table
collect.RecordBatch <- as.data.frame.RecordBatch
collect.Dataset <- function(x, ...) dplyr::collect(arrow_dplyr_query(x), ...)

#' @importFrom tidyselect vars_pull
pull.arrow_dplyr_query <- function(.data, var = -1) {
  # Resolve `var` (a name or position, default: last column) against the
  # user-facing names, narrow the query to that one column, collect it, and
  # return the first column of the result as a vector.
  query <- arrow_dplyr_query(.data)
  col <- vars_pull(names(query), !!enquo(var))
  query$selected_columns <- set_names(query$selected_columns[col], col)
  collected <- dplyr::collect(query)
  collected[[1]]
}
pull.Dataset <- pull.Table <- pull.RecordBatch <- pull.arrow_dplyr_query

summarise.arrow_dplyr_query <- function(.data, ...) {
  # summarise() for Table/RecordBatch queries. Narrows the query to the
  # columns the summaries and the grouping need, collects into R, and
  # summarises there. Not supported for Datasets (errors).
  .data <- arrow_dplyr_query(.data)
  if (query_on_dataset(.data)) {
    not_implemented_for_dataset("summarize()")
  }
  # Only retain the columns we need to do our aggregations.
  # all.vars() also reports symbols that are not columns (e.g. a variable `k`
  # from the calling environment in `mean(x) * k`, or the `.data` pronoun).
  # Selecting those would fail with an unknown-column error, so keep only the
  # names that exist in the query. intersect() also de-duplicates.
  vars_to_keep <- intersect(
    c(
      unlist(lapply(quos(...), all.vars)), # vars referenced in summarise
      dplyr::group_vars(.data)             # vars needed for grouping
    ),
    names(.data)
  )
  .data <- dplyr::select(.data, vars_to_keep)
  # TODO: determine whether work can be pushed down to Arrow
  dplyr::summarise(dplyr::collect(.data), ...)
}
summarise.Dataset <- summarise.Table <- summarise.RecordBatch <- summarise.arrow_dplyr_query

group_by.arrow_dplyr_query <- function(.data, ..., add = FALSE) {
  # Let dplyr resolve the grouping expressions, then record only the
  # resulting group names on the query (no data is touched)
  query <- arrow_dplyr_query(.data)
  prepared <- dplyr::group_by_prepare(query, ..., add = add)
  query$group_by_vars <- prepared$group_names
  query
}
group_by.Dataset <- group_by.Table <- group_by.RecordBatch <- group_by.arrow_dplyr_query

groups.arrow_dplyr_query <- function(x) {
  # Grouping columns as a list of symbols
  group_names <- dplyr::group_vars(x)
  syms(group_names)
}
# Bare Arrow objects are never grouped
groups.Dataset <- groups.Table <- groups.RecordBatch <- function(x) {
  NULL
}

group_vars.arrow_dplyr_query <- function(x) {
  # Grouping column names as stored on the query (renamed names)
  x$group_by_vars
}
# Bare Arrow objects are never grouped
group_vars.Dataset <- group_vars.Table <- group_vars.RecordBatch <- function(x) {
  NULL
}

ungroup.arrow_dplyr_query <- function(x, ...) {
  # Drop all grouping from the query; other query state is unchanged
  x$group_by_vars <- character()
  x
}
# Table, RecordBatch, and Dataset objects carry no grouping, so ungroup() is
# the identity. Accept `...` like the query method above: the previous
# `force` could not take extra arguments, so e.g. ungroup(tbl, col) errored
# with "unused argument".
ungroup.Dataset <- ungroup.Table <- ungroup.RecordBatch <- function(x, ...) x

mutate.arrow_dplyr_query <- function(.data, ...) {
  # mutate() is evaluated in R: collect the query, then delegate to dplyr.
  # Datasets are rejected so they are never pulled into memory implicitly.
  query <- arrow_dplyr_query(.data)
  if (query_on_dataset(query)) {
    not_implemented_for_dataset("mutate()")
  }
  # TODO: see if we can defer evaluating the expressions and not collect here.
  # It's different from filters (as currently implemented) because the basic
  # vector transformation functions aren't yet implemented in Arrow C++.
  collected <- dplyr::collect(query)
  dplyr::mutate(collected, ...)
}
mutate.Dataset <- mutate.Table <- mutate.RecordBatch <- mutate.arrow_dplyr_query
# transmute() "just works" because it calls mutate() internally
# TODO: add transmute() that does what summarise() does (select only the vars we need)

arrange.arrow_dplyr_query <- function(.data, ...) {
  # arrange() is evaluated in R on the collected data; Datasets are rejected
  # so they are never pulled into memory implicitly
  query <- arrow_dplyr_query(.data)
  if (query_on_dataset(query)) {
    not_implemented_for_dataset("arrange()")
  }
  collected <- dplyr::collect(query)
  dplyr::arrange(collected, ...)
}
arrange.Dataset <- arrange.Table <- arrange.RecordBatch <- arrange.arrow_dplyr_query

query_on_dataset <- function(x) {
  # TRUE when the query's source object is an Arrow Dataset (as opposed to
  # an in-memory Table/RecordBatch)
  source_obj <- x$.data
  inherits(source_obj, "Dataset")
}

not_implemented_for_dataset <- function(method) {
  # Signal a user-facing error that `method` (e.g. "mutate()") is not
  # available on Datasets, pointing the user at collect()
  hint <- "Call collect() first to pull data into R."
  stop(method, " is not currently implemented for Arrow Datasets. ", hint, call. = FALSE)
}
